{
  "metadata": {
    "kernelspec": {
      "name": "python3",
      "display_name": "Python 3",
      "language": "python"
    },
    "language_info": {
      "name": "python",
      "version": "3.11.11",
      "mimetype": "text/x-python",
      "codemirror_mode": {
        "name": "ipython",
        "version": 3
      },
      "pygments_lexer": "ipython3",
      "nbconvert_exporter": "python",
      "file_extension": ".py"
    },
    "colab": {
      "provenance": [],
      "gpuType": "T4",
      "include_colab_link": true
    },
    "accelerator": "GPU",
    "kaggle": {
      "accelerator": "nvidiaTeslaT4",
      "dataSources": [],
      "dockerImageVersionId": 31041,
      "isInternetEnabled": true,
      "language": "python",
      "sourceType": "notebook",
      "isGpuEnabled": true
    }
  },
  "nbformat_minor": 0,
  "nbformat": 4,
  "cells": [
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "view-in-github",
        "colab_type": "text"
      },
      "source": [
        "<a href=\"https://colab.research.google.com/github/Baijiong-Lin/LoRA-Torch/blob/main/examples/Finetune_open_clip_with_LoRA_Torch_on_CIFAR10.ipynb\" target=\"_parent\"><img src=\"https://colab.research.google.com/assets/colab-badge.svg\" alt=\"Open In Colab\"/></a>"
      ]
    },
    {
      "cell_type": "markdown",
      "source": [
        "### This example demonstrates how to apply LoRA-Torch to ``nn.MultiheadAttention`` in OpenCLIP. We greatly appreciate [Viet Q. Vo](https://vietvo89.github.io/)'s valuable contribution."
      ],
      "metadata": {
        "id": "Yhxyu7vivWl2"
      }
    },
    {
      "cell_type": "code",
      "source": [
        "!pip install open-clip-torch\n",
        "!pip install git+https://github.com/Baijiong-Lin/LoRA-Torch"
      ],
      "metadata": {
        "id": "753R63XXhzqE",
        "outputId": "ad3892df-aeb6-4ebe-b7ae-2400535ec3ab",
        "trusted": true,
        "execution": {
          "iopub.status.busy": "2025-06-09T01:09:26.555991Z",
          "iopub.execute_input": "2025-06-09T01:09:26.556177Z",
          "iopub.status.idle": "2025-06-09T01:10:48.370674Z",
          "shell.execute_reply.started": "2025-06-09T01:09:26.556160Z",
          "shell.execute_reply": "2025-06-09T01:10:48.369896Z"
        },
        "scrolled": true,
        "colab": {
          "base_uri": "https://localhost:8080/"
        }
      },
      "outputs": [
        {
          "output_type": "stream",
          "name": "stdout",
          "text": [
            "Requirement already satisfied: open-clip-torch in /usr/local/lib/python3.11/dist-packages (2.32.0)\n",
            "Requirement already satisfied: torch>=1.9.0 in /usr/local/lib/python3.11/dist-packages (from open-clip-torch) (2.6.0+cu124)\n",
            "Requirement already satisfied: torchvision in /usr/local/lib/python3.11/dist-packages (from open-clip-torch) (0.21.0+cu124)\n",
            "Requirement already satisfied: regex in /usr/local/lib/python3.11/dist-packages (from open-clip-torch) (2024.11.6)\n",
            "Requirement already satisfied: ftfy in /usr/local/lib/python3.11/dist-packages (from open-clip-torch) (6.3.1)\n",
            "Requirement already satisfied: tqdm in /usr/local/lib/python3.11/dist-packages (from open-clip-torch) (4.67.1)\n",
            "Requirement already satisfied: huggingface-hub in /usr/local/lib/python3.11/dist-packages (from open-clip-torch) (0.32.4)\n",
            "Requirement already satisfied: safetensors in /usr/local/lib/python3.11/dist-packages (from open-clip-torch) (0.5.3)\n",
            "Requirement already satisfied: timm in /usr/local/lib/python3.11/dist-packages (from open-clip-torch) (1.0.15)\n",
            "Requirement already satisfied: filelock in /usr/local/lib/python3.11/dist-packages (from torch>=1.9.0->open-clip-torch) (3.18.0)\n",
            "Requirement already satisfied: typing-extensions>=4.10.0 in /usr/local/lib/python3.11/dist-packages (from torch>=1.9.0->open-clip-torch) (4.14.0)\n",
            "Requirement already satisfied: networkx in /usr/local/lib/python3.11/dist-packages (from torch>=1.9.0->open-clip-torch) (3.5)\n",
            "Requirement already satisfied: jinja2 in /usr/local/lib/python3.11/dist-packages (from torch>=1.9.0->open-clip-torch) (3.1.6)\n",
            "Requirement already satisfied: fsspec in /usr/local/lib/python3.11/dist-packages (from torch>=1.9.0->open-clip-torch) (2025.3.2)\n",
            "Requirement already satisfied: nvidia-cuda-nvrtc-cu12==12.4.127 in /usr/local/lib/python3.11/dist-packages (from torch>=1.9.0->open-clip-torch) (12.4.127)\n",
            "Requirement already satisfied: nvidia-cuda-runtime-cu12==12.4.127 in /usr/local/lib/python3.11/dist-packages (from torch>=1.9.0->open-clip-torch) (12.4.127)\n",
            "Requirement already satisfied: nvidia-cuda-cupti-cu12==12.4.127 in /usr/local/lib/python3.11/dist-packages (from torch>=1.9.0->open-clip-torch) (12.4.127)\n",
            "Requirement already satisfied: nvidia-cudnn-cu12==9.1.0.70 in /usr/local/lib/python3.11/dist-packages (from torch>=1.9.0->open-clip-torch) (9.1.0.70)\n",
            "Requirement already satisfied: nvidia-cublas-cu12==12.4.5.8 in /usr/local/lib/python3.11/dist-packages (from torch>=1.9.0->open-clip-torch) (12.4.5.8)\n",
            "Requirement already satisfied: nvidia-cufft-cu12==11.2.1.3 in /usr/local/lib/python3.11/dist-packages (from torch>=1.9.0->open-clip-torch) (11.2.1.3)\n",
            "Requirement already satisfied: nvidia-curand-cu12==10.3.5.147 in /usr/local/lib/python3.11/dist-packages (from torch>=1.9.0->open-clip-torch) (10.3.5.147)\n",
            "Requirement already satisfied: nvidia-cusolver-cu12==11.6.1.9 in /usr/local/lib/python3.11/dist-packages (from torch>=1.9.0->open-clip-torch) (11.6.1.9)\n",
            "Requirement already satisfied: nvidia-cusparse-cu12==12.3.1.170 in /usr/local/lib/python3.11/dist-packages (from torch>=1.9.0->open-clip-torch) (12.3.1.170)\n",
            "Requirement already satisfied: nvidia-cusparselt-cu12==0.6.2 in /usr/local/lib/python3.11/dist-packages (from torch>=1.9.0->open-clip-torch) (0.6.2)\n",
            "Requirement already satisfied: nvidia-nccl-cu12==2.21.5 in /usr/local/lib/python3.11/dist-packages (from torch>=1.9.0->open-clip-torch) (2.21.5)\n",
            "Requirement already satisfied: nvidia-nvtx-cu12==12.4.127 in /usr/local/lib/python3.11/dist-packages (from torch>=1.9.0->open-clip-torch) (12.4.127)\n",
            "Requirement already satisfied: nvidia-nvjitlink-cu12==12.4.127 in /usr/local/lib/python3.11/dist-packages (from torch>=1.9.0->open-clip-torch) (12.4.127)\n",
            "Requirement already satisfied: triton==3.2.0 in /usr/local/lib/python3.11/dist-packages (from torch>=1.9.0->open-clip-torch) (3.2.0)\n",
            "Requirement already satisfied: sympy==1.13.1 in /usr/local/lib/python3.11/dist-packages (from torch>=1.9.0->open-clip-torch) (1.13.1)\n",
            "Requirement already satisfied: mpmath<1.4,>=1.1.0 in /usr/local/lib/python3.11/dist-packages (from sympy==1.13.1->torch>=1.9.0->open-clip-torch) (1.3.0)\n",
            "Requirement already satisfied: wcwidth in /usr/local/lib/python3.11/dist-packages (from ftfy->open-clip-torch) (0.2.13)\n",
            "Requirement already satisfied: packaging>=20.9 in /usr/local/lib/python3.11/dist-packages (from huggingface-hub->open-clip-torch) (24.2)\n",
            "Requirement already satisfied: pyyaml>=5.1 in /usr/local/lib/python3.11/dist-packages (from huggingface-hub->open-clip-torch) (6.0.2)\n",
            "Requirement already satisfied: requests in /usr/local/lib/python3.11/dist-packages (from huggingface-hub->open-clip-torch) (2.32.3)\n",
            "Requirement already satisfied: hf-xet<2.0.0,>=1.1.2 in /usr/local/lib/python3.11/dist-packages (from huggingface-hub->open-clip-torch) (1.1.2)\n",
            "Requirement already satisfied: numpy in /usr/local/lib/python3.11/dist-packages (from torchvision->open-clip-torch) (2.0.2)\n",
            "Requirement already satisfied: pillow!=8.3.*,>=5.3.0 in /usr/local/lib/python3.11/dist-packages (from torchvision->open-clip-torch) (11.2.1)\n",
            "Requirement already satisfied: MarkupSafe>=2.0 in /usr/local/lib/python3.11/dist-packages (from jinja2->torch>=1.9.0->open-clip-torch) (3.0.2)\n",
            "Requirement already satisfied: charset-normalizer<4,>=2 in /usr/local/lib/python3.11/dist-packages (from requests->huggingface-hub->open-clip-torch) (3.4.2)\n",
            "Requirement already satisfied: idna<4,>=2.5 in /usr/local/lib/python3.11/dist-packages (from requests->huggingface-hub->open-clip-torch) (3.10)\n",
            "Requirement already satisfied: urllib3<3,>=1.21.1 in /usr/local/lib/python3.11/dist-packages (from requests->huggingface-hub->open-clip-torch) (2.4.0)\n",
            "Requirement already satisfied: certifi>=2017.4.17 in /usr/local/lib/python3.11/dist-packages (from requests->huggingface-hub->open-clip-torch) (2025.4.26)\n",
            "Collecting git+https://github.com/Baijiong-Lin/LoRA-Torch\n",
            "  Cloning https://github.com/Baijiong-Lin/LoRA-Torch to /tmp/pip-req-build-1a5zm6vx\n",
            "  Running command git clone --filter=blob:none --quiet https://github.com/Baijiong-Lin/LoRA-Torch /tmp/pip-req-build-1a5zm6vx\n",
            "  Resolved https://github.com/Baijiong-Lin/LoRA-Torch to commit 3b6f10a3bdebfb0da1abeb4c265f914ed06759e4\n",
            "  Preparing metadata (setup.py) ... \u001b[?25l\u001b[?25hdone\n"
          ]
        }
      ],
      "execution_count": 1
    },
    {
      "cell_type": "code",
      "source": [
        "import torch\n",
        "import torch.nn as nn\n",
        "from torchvision import transforms\n",
        "from torch.utils.data import DataLoader\n",
        "import open_clip\n",
        "import loratorch as lora\n",
        "from tqdm import tqdm"
      ],
      "metadata": {
        "id": "T9hDV9jQiU0L",
        "trusted": true,
        "execution": {
          "iopub.status.busy": "2025-06-09T01:10:48.376803Z",
          "iopub.execute_input": "2025-06-09T01:10:48.376979Z",
          "iopub.status.idle": "2025-06-09T01:11:15.903375Z",
          "shell.execute_reply.started": "2025-06-09T01:10:48.376955Z",
          "shell.execute_reply": "2025-06-09T01:11:15.902530Z"
        }
      },
      "outputs": [],
      "execution_count": 2
    },
    {
      "cell_type": "markdown",
      "source": [
        "### A. Load Pre-trained Model"
      ],
      "metadata": {
        "id": "kDBpUYgYjrGw"
      }
    },
    {
      "cell_type": "code",
      "source": [
        "model, _, preprocess = open_clip.create_model_and_transforms('ViT-B-32', pretrained='openai')\n",
        "model = model.cuda()\n",
        "tokenizer = open_clip.get_tokenizer('ViT-B-32')"
      ],
      "metadata": {
        "trusted": true,
        "execution": {
          "iopub.status.busy": "2025-06-09T01:11:15.904141Z",
          "iopub.execute_input": "2025-06-09T01:11:15.904827Z",
          "iopub.status.idle": "2025-06-09T01:11:20.771415Z",
          "shell.execute_reply.started": "2025-06-09T01:11:15.904798Z",
          "shell.execute_reply": "2025-06-09T01:11:20.770632Z"
        },
        "id": "1bqndlUex2HS",
        "colab": {
          "base_uri": "https://localhost:8080/"
        },
        "outputId": "9bf32ef6-6a5f-45c0-a8ea-64fb9381a6dc"
      },
      "outputs": [
        {
          "output_type": "stream",
          "name": "stderr",
          "text": [
            "/usr/local/lib/python3.11/dist-packages/huggingface_hub/utils/_auth.py:94: UserWarning: \n",
            "The secret `HF_TOKEN` does not exist in your Colab secrets.\n",
            "To authenticate with the Hugging Face Hub, create a token in your settings tab (https://huggingface.co/settings/tokens), set it as secret in your Google Colab and restart your session.\n",
            "You will be able to reuse this secret in all of your notebooks.\n",
            "Please note that authentication is recommended but still optional to access public models or datasets.\n",
            "  warnings.warn(\n",
            "/usr/local/lib/python3.11/dist-packages/open_clip/factory.py:388: UserWarning: These pretrained weights were trained with QuickGELU activation but the model config does not have that enabled. Consider using a model config with a \"-quickgelu\" suffix or enable with a flag.\n",
            "  warnings.warn(\n"
          ]
        }
      ],
      "execution_count": 3
    },
    {
      "cell_type": "code",
      "source": [
        "# prompt: count trainable parameters of model?\n",
        "\n",
        "def count_parameters(model):\n",
        "    trainable_params = sum(p.numel() for p in model.parameters() if p.requires_grad)\n",
        "    vision_params = sum(p.numel() for p in model.visual.transformer.parameters() if p.requires_grad)\n",
        "    text_params = sum(p.numel() for p in model.transformer.parameters() if p.requires_grad)\n",
        "    embed_params = sum(p.numel() for p in model.token_embedding.parameters() if p.requires_grad)\n",
        "    total_params = sum(p.numel() for p in model.parameters())\n",
        "    print(f\"Total parameters: {total_params:,}\")\n",
        "    print(f\"Trainable parameters - Full model: {trainable_params:,}\")\n",
        "    print(f\"Trainable parameters - Vision: {vision_params:,}\")\n",
        "    print(f\"Trainable parameters - Text: {text_params:,}\")\n",
        "    print(f\"Trainable parameters - embedding: {embed_params:,}\")"
      ],
      "metadata": {
        "id": "VQoMyCMio4Kg",
        "trusted": true,
        "execution": {
          "iopub.status.busy": "2025-06-09T01:21:06.763946Z",
          "iopub.execute_input": "2025-06-09T01:21:06.764449Z",
          "iopub.status.idle": "2025-06-09T01:21:06.769661Z",
          "shell.execute_reply.started": "2025-06-09T01:21:06.764428Z",
          "shell.execute_reply": "2025-06-09T01:21:06.768894Z"
        }
      },
      "outputs": [],
      "execution_count": 4
    },
    {
      "cell_type": "code",
      "source": [
        "print('Original model before adding lora')\n",
        "count_parameters(model)"
      ],
      "metadata": {
        "trusted": true,
        "colab": {
          "base_uri": "https://localhost:8080/"
        },
        "id": "eJXm5sjmx2HW",
        "outputId": "77426425-634d-468e-e379-21db4d9312ec"
      },
      "outputs": [
        {
          "output_type": "stream",
          "name": "stdout",
          "text": [
            "Original model before adding lora\n",
            "Total parameters: 151,277,313\n",
            "Trainable parameters - Full model: 151,277,313\n",
            "Trainable parameters - Vision: 85,054,464\n",
            "Trainable parameters - Text: 37,828,608\n",
            "Trainable parameters - embedding: 25,296,896\n"
          ]
        }
      ],
      "execution_count": 5
    },
    {
      "cell_type": "markdown",
      "source": [
        "### B. Load CIFAR-10"
      ],
      "metadata": {
        "id": "OMHX0z87i5_9"
      }
    },
    {
      "cell_type": "code",
      "source": [
        "# prompt: load cifar10 dataset\n",
        "\n",
        "from torchvision.datasets import CIFAR10\n",
        "\n",
        "train_dataset = CIFAR10(\n",
        "    root=\"./data\", train=True, download=True,\n",
        "    transform=preprocess\n",
        ")\n",
        "test_dataset = CIFAR10(\n",
        "    root=\"./data\", train=False, download=True,\n",
        "    transform=preprocess\n",
        ")\n",
        "\n",
        "batch_size_train = 256\n",
        "batch_size_test = 256\n",
        "train_loader = DataLoader(train_dataset, batch_size=batch_size_train, shuffle=True, num_workers=4)\n",
        "test_loader = DataLoader(test_dataset, batch_size=batch_size_test, shuffle=False, num_workers=4)\n"
      ],
      "metadata": {
        "id": "mA1W_9IIiA_P",
        "outputId": "5d786a0d-d462-4034-ad85-3ce88c030ca8",
        "trusted": true,
        "execution": {
          "iopub.status.busy": "2025-06-09T01:11:20.772338Z",
          "iopub.execute_input": "2025-06-09T01:11:20.772628Z",
          "iopub.status.idle": "2025-06-09T01:11:26.195815Z",
          "shell.execute_reply.started": "2025-06-09T01:11:20.772603Z",
          "shell.execute_reply": "2025-06-09T01:11:26.195198Z"
        },
        "colab": {
          "base_uri": "https://localhost:8080/"
        }
      },
      "outputs": [
        {
          "output_type": "stream",
          "name": "stderr",
          "text": [
            "/usr/local/lib/python3.11/dist-packages/torch/utils/data/dataloader.py:624: UserWarning: This DataLoader will create 4 worker processes in total. Our suggested max number of worker in current system is 2, which is smaller than what this DataLoader is going to create. Please be aware that excessive worker creation might get DataLoader running slow or even freeze, lower the worker number to avoid potential slowness/freeze if necessary.\n",
            "  warnings.warn(\n"
          ]
        }
      ],
      "execution_count": 6
    },
    {
      "cell_type": "markdown",
      "source": [
        "### C. Fine-tune OpenCLIP with LoRA\n",
        "\n",
        "_Note:_\n",
        "\n",
        "Please make sure ``loratorch.MultiheadAttention`` uses the same input parameter values as [`nn.MultiheadAttention`](https://docs.pytorch.org/docs/2.6/generated/torch.nn.MultiheadAttention.html#multiheadattention).\n",
        "\n",
        "For exmaple, the default value for batch_first in `nn.MultiheadAttention` is `False`, but `open_clip` sets it to `True` in some `attn` layers. The discussion of this can be found [here](https://github.com/Baijiong-Lin/LoRA-Torch/issues/6#issuecomment-2954122864).\n",
        "\n",
        "The best way of employing `loratorch.MultiheadAttention` is the following:\n",
        "```python\n",
        "lora_multihead = lora.MultiheadAttention(r=r,\n",
        "                        lora_alpha=lora_alpha,\n",
        "                        enable_lora=enable_lora,\n",
        "                        embed_dim=multihead.embed_dim,\n",
        "                        num_heads=multihead.num_heads,\n",
        "                        dropout=multihead.dropout,\n",
        "                        bias=True if hasattr(multihead, \"in_proj_bias\") else False,\n",
        "                        add_bias_kv=False if multihead.bias_k==None else True,\n",
        "                        add_zero_attn=multihead.add_zero_attn,\n",
        "                        kdim=multihead.kdim,\n",
        "                        vdim=multihead.vdim,\n",
        "                        batch_first=multihead.batch_first)\n",
        "```"
      ],
      "metadata": {
        "id": "Ahjx5_MPlbvU"
      }
    },
    {
      "cell_type": "markdown",
      "source": [
        "#### Apply LoRA to `attn` and `mlp`"
      ],
      "metadata": {
        "id": "Y1sLkX5tDZGz"
      }
    },
    {
      "cell_type": "code",
      "source": [
        "def apply_lora_attn_mlp(model, encoder_type='visual', rank=16, lora_alpha=32, mlp=True, attn=True):\n",
        "    if encoder_type == 'visual':\n",
        "        encoder = model.visual.transformer\n",
        "    elif encoder_type == 'text':\n",
        "        encoder = model.transformer\n",
        "    else:\n",
        "        raise ValueError(\"Invalid encoder_type. Choose 'visual' or 'text'.\")\n",
        "\n",
        "    enable_lora=['q', 'k', 'v', 'o']\n",
        "    for i, resblock in enumerate(encoder.resblocks):\n",
        "        if hasattr(resblock, 'attn') and attn:\n",
        "            multihead = resblock.attn\n",
        "            lora_multihead = lora.MultiheadAttention(r=rank,\n",
        "                                    lora_alpha=lora_alpha,\n",
        "                                    enable_lora=enable_lora,\n",
        "                                    embed_dim=multihead.embed_dim,\n",
        "                                    num_heads=multihead.num_heads,\n",
        "                                    dropout=multihead.dropout,\n",
        "                                    bias=True if hasattr(multihead, \"in_proj_bias\") else False,\n",
        "                                    add_bias_kv=False if multihead.bias_k==None else True,\n",
        "                                    add_zero_attn=multihead.add_zero_attn,\n",
        "                                    kdim=multihead.kdim,\n",
        "                                    vdim=multihead.vdim,\n",
        "                                    batch_first=multihead.batch_first)\n",
        "            lora_multihead.load_state_dict(multihead.state_dict(), strict=False)\n",
        "            resblock.attn = lora_multihead\n",
        "\n",
        "        if hasattr(resblock, 'mlp') and mlp:\n",
        "            old_mlp_fc=resblock.mlp.c_fc\n",
        "            old_mlp_proj=resblock.mlp.c_proj\n",
        "            new_mlp_fc = lora.Linear(\n",
        "                old_mlp_fc.in_features,\n",
        "                old_mlp_fc.out_features,\n",
        "                bias=True if hasattr(old_mlp_fc, \"bias\") else False,\n",
        "                r=rank,\n",
        "                lora_alpha=lora_alpha,\n",
        "            )\n",
        "            new_mlp_proj = lora.Linear(\n",
        "                old_mlp_proj.in_features,\n",
        "                old_mlp_proj.out_features,\n",
        "                bias=True if hasattr(old_mlp_proj, \"bias\") else False,\n",
        "                r=rank,\n",
        "                lora_alpha=lora_alpha,\n",
        "            )\n",
        "            new_mlp_fc.load_state_dict(old_mlp_fc.state_dict(),strict=False)\n",
        "            new_mlp_proj.load_state_dict(old_mlp_proj.state_dict(),strict=False)\n",
        "            resblock.mlp.c_fc = new_mlp_fc\n",
        "            resblock.mlp.c_proj = new_mlp_proj\n",
        "\n",
        "    lora.mark_only_lora_as_trainable(model)\n",
        "    return model"
      ],
      "metadata": {
        "trusted": true,
        "execution": {
          "iopub.status.busy": "2025-06-09T01:20:16.374517Z",
          "iopub.execute_input": "2025-06-09T01:20:16.375239Z",
          "iopub.status.idle": "2025-06-09T01:20:16.382721Z",
          "shell.execute_reply.started": "2025-06-09T01:20:16.375208Z",
          "shell.execute_reply": "2025-06-09T01:20:16.382180Z"
        },
        "id": "UQ8c81GOx2HZ"
      },
      "outputs": [],
      "execution_count": 7
    },
    {
      "cell_type": "code",
      "source": [
        "apply_lora_attn_mlp(model, encoder_type='visual', rank=16, lora_alpha=32, mlp=True, attn=True)\n",
        "tokenizer = open_clip.get_tokenizer('ViT-B-32')"
      ],
      "metadata": {
        "trusted": true,
        "execution": {
          "iopub.status.busy": "2025-06-09T01:20:17.384221Z",
          "iopub.execute_input": "2025-06-09T01:20:17.384480Z",
          "iopub.status.idle": "2025-06-09T01:20:18.691620Z",
          "shell.execute_reply.started": "2025-06-09T01:20:17.384462Z",
          "shell.execute_reply": "2025-06-09T01:20:18.690864Z"
        },
        "id": "WThYJ3zzx2HZ"
      },
      "outputs": [],
      "execution_count": 8
    },
    {
      "cell_type": "code",
      "source": [
        "for name, param in model.visual.transformer.resblocks[0].named_parameters():\n",
        "    print(name, param.requires_grad)\n",
        "\n",
        "# after adding lora\n",
        "print(\"\\nAfter adding LoRA to Attn+MLP:\")\n",
        "count_parameters(model)"
      ],
      "metadata": {
        "trusted": true,
        "execution": {
          "iopub.status.busy": "2025-06-09T01:21:17.260662Z",
          "iopub.execute_input": "2025-06-09T01:21:17.260951Z",
          "iopub.status.idle": "2025-06-09T01:21:17.267985Z",
          "shell.execute_reply.started": "2025-06-09T01:21:17.260934Z",
          "shell.execute_reply": "2025-06-09T01:21:17.267184Z"
        },
        "colab": {
          "base_uri": "https://localhost:8080/"
        },
        "id": "DT_NoSaTx2HZ",
        "outputId": "2427de82-3257-4b98-fe21-64c9330ee75c"
      },
      "outputs": [
        {
          "output_type": "stream",
          "name": "stdout",
          "text": [
            "ln_1.weight False\n",
            "ln_1.bias False\n",
            "attn.in_proj_weight False\n",
            "attn.in_proj_bias False\n",
            "attn.o_lora_A True\n",
            "attn.o_lora_B True\n",
            "attn.qkv_lora_A True\n",
            "attn.qkv_lora_B True\n",
            "attn.out_proj.weight False\n",
            "attn.out_proj.bias False\n",
            "ln_2.weight False\n",
            "ln_2.bias False\n",
            "mlp.c_fc.weight False\n",
            "mlp.c_fc.bias False\n",
            "mlp.c_fc.w_lora_A True\n",
            "mlp.c_fc.w_lora_B True\n",
            "mlp.c_proj.weight False\n",
            "mlp.c_proj.bias False\n",
            "mlp.c_proj.w_lora_A True\n",
            "mlp.c_proj.w_lora_B True\n",
            "\n",
            "After adding LoRA to Attn+MLP:\n",
            "Total parameters: 153,636,609\n",
            "Trainable parameters - Full model: 2,359,296\n",
            "Trainable parameters - Vision: 2,359,296\n",
            "Trainable parameters - Text: 0\n",
            "Trainable parameters - embedding: 0\n"
          ]
        }
      ],
      "execution_count": 9
    },
    {
      "cell_type": "code",
      "source": [
        "# Tokenizer and text embeddings\n",
        "model.cuda()\n",
        "\n",
        "tokenizer = open_clip.get_tokenizer(\"ViT-B-32\")\n",
        "classnames = train_dataset.classes\n",
        "text_inputs = tokenizer([f\"a photo of a {label}\" for label in classnames]).cuda()\n",
        "with torch.no_grad():\n",
        "    text_features = model.encode_text(text_inputs)\n",
        "    text_features = text_features / text_features.norm(dim=-1, keepdim=True)\n",
        "\n",
        "# Optimizer\n",
        "optimizer = torch.optim.AdamW(model.parameters(), lr=1e-4)"
      ],
      "metadata": {
        "trusted": true,
        "execution": {
          "iopub.status.busy": "2025-06-09T01:21:34.494110Z",
          "iopub.execute_input": "2025-06-09T01:21:34.494363Z",
          "iopub.status.idle": "2025-06-09T01:21:35.033731Z",
          "shell.execute_reply.started": "2025-06-09T01:21:34.494344Z",
          "shell.execute_reply": "2025-06-09T01:21:35.032959Z"
        },
        "id": "RdVjI-pvx2HZ"
      },
      "outputs": [],
      "execution_count": 10
    },
    {
      "cell_type": "code",
      "source": [
        "# Train loop\n",
        "model.train()\n",
        "for epoch in range(3):\n",
        "    total_loss = 0\n",
        "    correct = 0\n",
        "    total = 0\n",
        "    for images, labels in tqdm(train_loader):\n",
        "        images, labels = images.cuda(), labels.cuda()\n",
        "        image_features = model.encode_image(images)\n",
        "        image_features = image_features / image_features.norm(dim=-1, keepdim=True)\n",
        "        logits = image_features @ text_features.t()\n",
        "        loss = nn.CrossEntropyLoss()(logits, labels)\n",
        "\n",
        "        optimizer.zero_grad()\n",
        "        loss.backward()\n",
        "        optimizer.step()\n",
        "        total_loss += loss.item()\n",
        "\n",
        "        preds = logits.argmax(dim=1)\n",
        "        correct += (preds == labels).sum().item()\n",
        "        total += labels.size(0)\n",
        "        # (!!!) reregister model param to ensure they are in model.state_dict() and model.parameters()\n",
        "        # (!!!) Without this line, the performance does not be affected but you will find that some weights are missing in model.state_dict() and model.parameters()\n",
        "        lora.register_model_param_after_backward(model)\n",
        "\n",
        "    acc = correct / total\n",
        "    print(f\"Epoch {epoch+1}: Loss={total_loss:.4f}, Accuracy={acc:.4f}\")"
      ],
      "metadata": {
        "outputId": "153083d4-3f77-4b1d-bc7e-a9798cfb1139",
        "trusted": true,
        "execution": {
          "iopub.status.busy": "2025-06-09T01:21:35.369183Z",
          "iopub.execute_input": "2025-06-09T01:21:35.369507Z",
          "iopub.status.idle": "2025-06-09T01:54:15.576408Z",
          "shell.execute_reply.started": "2025-06-09T01:21:35.369485Z",
          "shell.execute_reply": "2025-06-09T01:54:15.575561Z"
        },
        "id": "UCgTKj1qx2HZ",
        "colab": {
          "base_uri": "https://localhost:8080/"
        }
      },
      "outputs": [
        {
          "output_type": "stream",
          "name": "stderr",
          "text": [
            "100%|██████████| 196/196 [06:43<00:00,  2.06s/it]\n"
          ]
        },
        {
          "output_type": "stream",
          "name": "stdout",
          "text": [
            "Epoch 1: Loss=397.7533, Accuracy=0.9529\n"
          ]
        },
        {
          "output_type": "stream",
          "name": "stderr",
          "text": [
            "100%|██████████| 196/196 [06:40<00:00,  2.04s/it]\n"
          ]
        },
        {
          "output_type": "stream",
          "name": "stdout",
          "text": [
            "Epoch 2: Loss=387.0708, Accuracy=0.9796\n"
          ]
        },
        {
          "output_type": "stream",
          "name": "stderr",
          "text": [
            "100%|██████████| 196/196 [06:39<00:00,  2.04s/it]"
          ]
        },
        {
          "output_type": "stream",
          "name": "stdout",
          "text": [
            "Epoch 3: Loss=386.2361, Accuracy=0.9901\n"
          ]
        },
        {
          "output_type": "stream",
          "name": "stderr",
          "text": [
            "\n"
          ]
        }
      ],
      "execution_count": 11
    },
    {
      "cell_type": "code",
      "source": [],
      "metadata": {
        "trusted": true,
        "id": "pMS8k5sXx2Ha"
      },
      "outputs": [],
      "execution_count": 11
    }
  ]
}